@hjh0119 commented Sep 11, 2025

Optimize weight synchronization between the training model and the inference engine (vLLM):

LoRA

  1. Synchronize/load only the trained adapter weights (both colocate and server modes); for server mode, pass --vllm_enable_lora true on the rollout side.
  2. In server mode, transmit flattened adapter weights to reduce the communication overhead of parameter transfer (see the sketch below).
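A minimal sketch of the flattening idea, with illustrative helper names (not the PR's actual functions): the adapter tensors are packed into one contiguous buffer plus shape metadata, sent in a single transfer, and unpacked on the rollout side.

```python
# Hedged sketch, not the PR's implementation: pack/unpack a LoRA adapter
# state_dict into a single flat tensor to cut per-tensor transfer overhead.
import torch

def flatten_adapter(state_dict):
    """Pack adapter tensors into one flat buffer plus (name, shape, numel) metadata."""
    meta = [(name, t.shape, t.numel()) for name, t in state_dict.items()]
    flat = torch.cat([t.detach().reshape(-1) for t in state_dict.values()])
    return flat, meta

def unflatten_adapter(flat, meta):
    """Rebuild the adapter state_dict from the flat buffer and metadata."""
    out, offset = {}, 0
    for name, shape, numel in meta:
        out[name] = flat[offset:offset + numel].view(shape)
        offset += numel
    return out
```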

FULL

  1. Remove the original per-tensor synchronization logic and adopt a bucketing strategy to reduce redundant communication requests and overhead, especially for MoE models, which have more tensors than dense models (a bucketing sketch follows this list).
  2. Remove the per-tensor gather and revert to batched gather; as a result, move_model_batches now works with full-parameter training.
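A minimal sketch of the bucketing idea, with an illustrative bucket size and helper name (not the PR's code): parameters are grouped into size-bounded buckets so each bucket can be flattened and synchronized with one collective call rather than one call per tensor.

```python
# Hedged sketch: group parameters into ~512 MB buckets; each bucket is then
# flattened and sent with a single broadcast/gather instead of per-tensor calls.
def bucket_parameters(named_params, bucket_bytes=512 * 1024 * 1024):
    buckets, current, current_bytes = [], [], 0
    for name, param in named_params:
        nbytes = param.numel() * param.element_size()
        if current and current_bytes + nbytes > bucket_bytes:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append((name, param))
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets
```

This matters most for MoE models, where the large number of expert tensors makes per-tensor requests disproportionately expensive.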

Update: Built-in Accuracy Reward

  • Fixed several cases where accuracy could not be correctly evaluated.
  • Added corresponding unit tests (a minimal test sketch follows).
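As an illustration of the kind of cases such tests cover, here is a minimal pytest-style sketch, assuming a math_verify-style parse/verify backend; the built-in reward in this PR may be implemented differently.

```python
# Hedged sketch: assumes the accuracy check is backed by math_verify's
# parse/verify; the actual built-in reward may wrap this differently.
from math_verify import parse, verify

def test_equivalent_answers_accepted():
    # "1/2" and "0.5" should count as the same answer
    assert verify(parse(r"$\frac{1}{2}$"), parse("0.5"))

def test_wrong_answer_rejected():
    assert not verify(parse("42"), parse("41"))
```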

@hjh0119 changed the title [grpo] Optimize LoRA training vLLM weight synchronization → [WIP] Optimize LoRA training vLLM weight synchronization on Sep 11, 2025
@hjh0119 marked this pull request as ready for review on September 11, 2025 09:27
@hjh0119 commented Sep 11, 2025

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request introduces an optimization for LoRA training with vLLM by enabling in-memory weight synchronization using flattened tensors. This avoids disk I/O and should improve training speed. The changes involve adding new arguments, new protocol definitions, and new methods in the rollout engine and GRPO trainer. A key part of the implementation is monkey-patching vLLM to support loading LoRA adapters from tensors. The overall approach is sound, but there are a few areas that need attention, such as ensuring deterministic adapter selection, cleaning up commented-out code, and addressing TODO comments.

@hjh0119 changed the title [WIP] Optimize LoRA training vLLM weight synchronization → Optimize LoRA training vLLM weight synchronization on Sep 12, 2025
@hjh0119 commented Sep 12, 2025

Qwen2.5-VL-7B-Instruct, server mode, tp=2, dp=2 → 10× speed-up


@hjh0119 commented Oct 13, 2025

Qwen2.5-VL-32B-Instruct (LoRA training): weight synchronization time reduced from 4 s to 1 s.

lora+ indicates that only the LoRA adapter weights are synchronized, and the parameter --vllm_enable_lora true is used on the rollout side.


@hjh0119 commented Oct 17, 2025

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request introduces significant optimizations for weight synchronization in GRPO training with vLLM, alongside a fix for the accuracy reward calculation. For LoRA training, it cleverly syncs only the flattened adapter weights to reduce communication overhead, using a monkey-patch to allow vLLM to accept in-memory tensors. For full-parameter training, it implements a bucketing strategy to group tensors and minimize synchronization calls, which is especially beneficial for MoE models. The changes are extensive but well-structured, with corresponding updates to documentation and example scripts. I've identified one high-severity issue related to a mismatch in LoRA adapter naming that could prevent the adapter from being correctly applied during inference. Overall, this is a very strong contribution that should yield substantial performance improvements.

```python
if lora_int_ids:
    # since max_lora = 1, pick the first lora
    adapter_request = LoRARequest(
        lora_name=f'lora_{lora_int_ids[0]}',
```
gemini-code-assist bot (severity: high):

There appears to be a mismatch in how the lora_name is constructed here compared to when the adapter is added. When adding the adapter (in both colocate and server modes), the lora_name is set to the string representation of the integer ID (e.g., f'{lora_int_id}'). However, here you are prefixing it with 'lora_'. This will likely cause vLLM to fail to find the requested adapter during inference. To fix this, the prefix should be removed to ensure consistency.

Suggested change:

```diff
-        lora_name=f'lora_{lora_int_ids[0]}',
+        lora_name=str(lora_int_ids[0]),
```
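For illustration, a hedged sketch of keeping the name consistent between registering and requesting the adapter via vLLM's public LoRARequest API; the model name, path, and integer ID below are placeholders, and the PR itself supplies the adapter from in-memory tensors rather than a path.

```python
# Hedged sketch: the adapter must be requested under the same lora_name it was
# registered with. Values here are illustrative, not taken from the PR.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", enable_lora=True, max_loras=1)

lora_int_id = 1
request = LoRARequest(
    lora_name=str(lora_int_id),   # same string used when the adapter was added
    lora_int_id=lora_int_id,
    lora_path="/tmp/adapter",     # placeholder path
)

outputs = llm.generate(
    ["Write a haiku about GPUs."],
    SamplingParams(max_tokens=64),
    lora_request=request,
)
```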

@hjh0119 changed the title [grpo] Optimize vLLM weight synchronization → [grpo] Optimize vLLM weight synchronization & update built-in accuracy reward on Oct 17, 2025
@hjh0119 commented Oct 17, 2025

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request introduces significant optimizations for weight synchronization with vLLM, for both LoRA and full-parameter training. For LoRA, it enables syncing only adapter weights and uses flattened tensors to reduce overhead. For full training, it implements a bucketing strategy for more efficient weight transfer. The built-in accuracy reward function has also been improved to handle more cases correctly, and unit tests have been added.

My review focuses on ensuring the new synchronization logic is consistent and robust. I've suggested a refactoring to unify the parameter bucketing logic for full-parameter training, making it more maintainable and correct. I also pointed out a minor inconsistency in the documentation between the English and Chinese versions. Overall, these are great improvements to the training efficiency and correctness.

@hjh0119 commented Oct 17, 2025

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request introduces significant optimizations for vLLM weight synchronization in GRPO training, covering both LoRA and full-parameter scenarios. The use of flattened tensors and a bucketing strategy to reduce communication overhead is a solid improvement. Additionally, the built-in accuracy reward mechanism has been updated to handle more cases correctly, and new unit tests have been added to ensure its robustness. The documentation is also thoroughly updated to reflect these new features. The code is well-structured, particularly the refactoring of the weight synchronization logic. I have a few minor suggestions to enhance code consistency and maintainability.

Comment on lines 90 to 92:

```python
        """
        Receives updated weights from the client process and updates the named parameter in the model.
        """
```
gemini-code-assist bot (severity: medium):

The docstring here is slightly misleading. It states that the method "updates the named parameter in the model," but it actually receives and applies a full flattened LoRA adapter. A more precise description would improve clarity.

Suggested change:

```diff
-        """
-        Receives updated weights from the client process and updates the named parameter in the model.
-        """
+        """
+        Receives and applies a flattened LoRA adapter to the model.
+        """
```

Comment on lines +685 to 688:

```python
        state_dict = prefix_removed if self.rollout_enable_lora else {
            k.replace('.base_layer', ''): v
            for k, v in prefix_removed.items()
        }
```
gemini-code-assist bot (severity: medium):

There's an inconsistency in how .base_layer is removed from parameter names. Here, k.replace('.base_layer', '') is used, while the split_batches method uses re.sub(r'\.base_layer\.', '.', name). Using re.sub is more robust as it specifically targets .base_layer. as a whole component in the name, preventing accidental replacements if .base_layer appears as part of another name. Using re.sub here would improve consistency and robustness.

Suggested change:

```diff
-        state_dict = prefix_removed if self.rollout_enable_lora else {
-            k.replace('.base_layer', ''): v
-            for k, v in prefix_removed.items()
-        }
+        state_dict = prefix_removed if self.rollout_enable_lora else {
+            re.sub(r'\.base_layer\.', '.', k): v
+            for k, v in prefix_removed.items()
+        }
```
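A small illustration of the difference, using hypothetical parameter names:

```python
import re

name = "model.layers.0.self_attn.q_proj.base_layer.weight"
# Both forms give the intended result on a normal PEFT-style name:
print(name.replace('.base_layer', ''))        # model.layers.0.self_attn.q_proj.weight
print(re.sub(r'\.base_layer\.', '.', name))   # model.layers.0.self_attn.q_proj.weight

# Hypothetical name where '.base_layer' is only a prefix of a longer component:
odd = "model.layers.0.mlp.base_layer_norm.weight"
print(odd.replace('.base_layer', ''))         # model.layers.0.mlp_norm.weight (mangled)
print(re.sub(r'\.base_layer\.', '.', odd))    # unchanged, as intended
```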

```python
    return gather_if_zero3


def patch_vllm_load_adapter():
```
gemini-code-assist bot (severity: medium):

This function monkey-patches vLLM's _load_adapter to support loading LoRA adapters from in-memory tensors. While this is a clever solution, monkey-patching can be fragile and might break with future updates to the vLLM library. To improve maintainability, it would be beneficial to add a comment in the docstring specifying which version(s) of vLLM this patch is compatible with. This will make it easier to track and update when vLLM is upgraded.
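One way to make the compatibility assumption explicit is a version guard around the patch; the version bound, warning text, and body below are illustrative, not taken from this PR.

```python
# Hedged sketch: warn when the running vLLM is newer than the versions the
# monkey-patch was validated against. The bound 0.8.5 is a placeholder.
import warnings

import vllm
from packaging import version

_PATCH_TESTED_MAX = version.parse("0.8.5")  # hypothetical upper bound

def patch_vllm_load_adapter():
    """Patch vLLM's LoRA loading to accept in-memory tensors.

    Validated against vLLM <= 0.8.5 (illustrative); revisit on upgrade.
    """
    if version.parse(vllm.__version__) > _PATCH_TESTED_MAX:
        warnings.warn(
            f"patch_vllm_load_adapter was validated up to vLLM {_PATCH_TESTED_MAX}; "
            f"running {vllm.__version__}, the patched internals may have changed.")
    # ... apply the monkey-patch here ...
```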

@hjh0119 merged commit 8eda1d3 into modelscope:main on Oct 17, 2025
1 of 2 checks passed
@hjh0119 deleted the lora+ branch on October 17, 2025 09:13